/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_ASCEND_CL
#include <fcntl.h>

#ifdef _POSIX_C_SOURCE
#undef _POSIX_C_SOURCE
#endif

#ifdef _XOPEN_SOURCE
#undef _XOPEN_SOURCE
#endif

#include <ge/ge_api.h>
#include <graph/attr_value.h>
#include <graph/operator_factory.h>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/fleet/ascend_wrapper.h"
#include "paddle/fluid/platform/ascend_npu_info.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/pybind/ascend_wrapper_py.h"

using namespace ge;  // NOLINT
namespace py = pybind11;

namespace paddle {
namespace pybind {

#ifdef PADDLE_WITH_ASCEND_STRING
using AscendString = AscendString;
#else
using AscendString = std::string;
#endif

void BindAscendWrapper(py::module *m) {
  py::class_<framework::AscendInstance,
             std::shared_ptr<framework::AscendInstance>>(*m, "AscendInstance")
      .def(py::init([]() { return framework::AscendInstance::GetInstance(); }))
      .def("init_global_resources",
           &framework::AscendInstance::InitGlobalResouces,
           py::call_guard<py::gil_scoped_release>())
      .def("destroy_global_resources",
           &framework::AscendInstance::DestroyGlobalResouces,
           py::call_guard<py::gil_scoped_release>())
      .def("add_ascend_subgraph", &framework::AscendInstance::AddAscendSubgraph,
           py::call_guard<py::gil_scoped_release>());
}

std::map<AscendString, AscendString> convert_map(
    const std::map<std::string, std::string> &options) {
  std::map<AscendString, AscendString> rets;
  for (auto &option : options) {
    AscendString key = option.first.c_str();
    AscendString val = option.second.c_str();
    rets[key] = val;
  }
  return rets;
}

ge::Status ge_initialize(
    std::map<std::string, std::string> &options) {  // NOLINT
  py::gil_scoped_release release;
  auto init_options = convert_map(options);
  ge::Status res = ge::GEInitialize(init_options);
  PADDLE_ENFORCE_EQ(res, ge::SUCCESS, platform::errors::Fatal(
                                          "ge initialize not success:%d", res));
  py::gil_scoped_acquire acquire;
  return res;
}

enum AttrType {
  AT_INT64 = 0,
  AT_INT32,
  AT_UINT32,
  AT_LIST_INT64,
  AT_LIST_INT32,
  AT_LIST_UINT32,
  AT_FLOAT,
  AT_LIST_FLOAT,
  AT_ATTR_VALUE,
  AT_STRING,
  AT_LIST_STRING,
  AT_BOOL,
  AT_LIST_BOOL,
  AT_TENSOR,
  AT_LIST_TENSOR,
  AT_LIST_UINT8,
  AT_LIST_LIST_INT64,
  AT_LIST_DT,
  AT_DT,
  AT_LIST_NAMEATTR,
  AT_NAMEATTR
};

#ifdef PADDLE_WITH_ASCEND
void BindAscendDevice(py::module *m) {
  py::class_<platform::ascend::NPUDevice>(*m, "NPUDevice")
      .def_static(
          "get_device_count",
          static_cast<int (*)()>(&platform::ascend::NPUDevice::GetDeviceCount));
}
#endif

void BindAscendGraph(py::module *m) {
  m->def("ge_initialize", &ge_initialize, "GEInitialize");
  m->def("ge_finalize", &GEFinalize, "GEFinalize");

  // enum
  py::enum_<GraphRunMode>(*m, "GEGraphRunMode")
      .value("PREDICTION", GraphRunMode::PREDICTION)
      .value("TRAIN", GraphRunMode::TRAIN)
      .export_values();

  py::enum_<DataType>(*m, "GEDataType")
      .value("DT_FLOAT", DataType::DT_FLOAT)
      .value("DT_FLOAT16", DataType::DT_FLOAT16)
      .value("DT_INT8", DataType::DT_INT8)
      .value("DT_INT16", DataType::DT_INT16)
      .value("DT_UINT16", DataType::DT_UINT16)
      .value("DT_UINT8", DataType::DT_UINT8)
      .value("DT_INT32", DataType::DT_INT32)
      .value("DT_INT64", DataType::DT_INT64)
      .value("DT_UINT32", DataType::DT_UINT32)
      .value("DT_UINT64", DataType::DT_UINT64)
      .value("DT_BOOL", DataType::DT_BOOL)
      .value("DT_DOUBLE", DataType::DT_DOUBLE)
      .value("DT_STRING", DataType::DT_STRING)
      .value("DT_DUAL_SUB_INT8", DataType::DT_DUAL_SUB_INT8)
      .value("DT_DUAL_SUB_UINT8", DataType::DT_DUAL_SUB_UINT8)
      .value("DT_COMPLEX64", DataType::DT_COMPLEX64)
      .value("DT_COMPLEX128", DataType::DT_COMPLEX128)
      .value("DT_QINT8", DataType::DT_QINT8)
      .value("DT_QINT16", DataType::DT_QINT16)
      .value("DT_QINT32", DataType::DT_QINT32)
      .value("DT_QUINT8", DataType::DT_QUINT8)
      .value("DT_QUINT16", DataType::DT_QUINT16)
      .value("DT_RESOURCE", DataType::DT_RESOURCE)
      .value("DT_STRING_REF", DataType::DT_STRING_REF)
      .value("DT_DUAL", DataType::DT_DUAL)
      .value("DT_UNDEFINED", DataType::DT_UNDEFINED)
      .export_values();

  py::enum_<Format>(*m, "GEFormat")
      .value("FORMAT_NCHW", Format::FORMAT_NCHW)
      .value("FORMAT_NHWC", Format::FORMAT_NHWC)
      .value("FORMAT_ND", Format::FORMAT_ND)
      .value("FORMAT_NC1HWC0", Format::FORMAT_NC1HWC0)
      .value("FORMAT_FRACTAL_Z", Format::FORMAT_FRACTAL_Z)
      .value("FORMAT_NC1C0HWPAD", Format::FORMAT_NC1C0HWPAD)
      .value("FORMAT_NHWC1C0", Format::FORMAT_NHWC1C0)
      .value("FORMAT_FSR_NCHW", Format::FORMAT_FSR_NCHW)
      .value("FORMAT_FRACTAL_DECONV", Format::FORMAT_FRACTAL_DECONV)
      .value("FORMAT_C1HWNC0", Format::FORMAT_C1HWNC0)
      .value("FORMAT_FRACTAL_DECONV_TRANSPOSE",
             Format::FORMAT_FRACTAL_DECONV_TRANSPOSE)
      .value("FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS",
             Format::FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS)
      .value("FORMAT_NC1HWC0_C04", Format::FORMAT_NC1HWC0_C04)
      .value("FORMAT_FRACTAL_Z_C04", Format::FORMAT_FRACTAL_Z_C04)
      .value("FORMAT_CHWN", Format::FORMAT_CHWN)
      .value("FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS",
             Format::FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS)
      .value("FORMAT_HWCN", Format::FORMAT_HWCN)
      .value("FORMAT_NC1KHKWHWC0", Format::FORMAT_NC1KHKWHWC0)
      .value("FORMAT_BN_WEIGHT", Format::FORMAT_BN_WEIGHT)
      .value("FORMAT_FILTER_HWCK", Format::FORMAT_FILTER_HWCK)
      .value("FORMAT_HASHTABLE_LOOKUP_LOOKUPS",
             Format::FORMAT_HASHTABLE_LOOKUP_LOOKUPS)
      .value("FORMAT_HASHTABLE_LOOKUP_KEYS",
             Format::FORMAT_HASHTABLE_LOOKUP_KEYS)
      .value("FORMAT_HASHTABLE_LOOKUP_VALUE",
             Format::FORMAT_HASHTABLE_LOOKUP_VALUE)
      .value("FORMAT_HASHTABLE_LOOKUP_OUTPUT",
             Format::FORMAT_HASHTABLE_LOOKUP_OUTPUT)
      .value("FORMAT_HASHTABLE_LOOKUP_HITS",
             Format::FORMAT_HASHTABLE_LOOKUP_HITS)
      .value("FORMAT_C1HWNCoC0", Format::FORMAT_C1HWNCoC0)
      .value("FORMAT_MD", Format::FORMAT_MD)
      .value("FORMAT_NDHWC", Format::FORMAT_NDHWC)
      .value("FORMAT_FRACTAL_ZZ", Format::FORMAT_FRACTAL_ZZ)
      .value("FORMAT_FRACTAL_NZ", Format::FORMAT_FRACTAL_NZ)
      .value("FORMAT_NCDHW", Format::FORMAT_NCDHW)
      .value("FORMAT_DHWCN", Format::FORMAT_DHWCN)
      .value("FORMAT_NDC1HWC0", Format::FORMAT_NDC1HWC0)
      .value("FORMAT_FRACTAL_Z_3D", Format::FORMAT_FRACTAL_Z_3D)
      .value("FORMAT_CN", Format::FORMAT_CN)
      .value("FORMAT_NC", Format::FORMAT_NC)
      .value("FORMAT_DHWNC", Format::FORMAT_DHWNC)
      .value("FORMAT_FRACTAL_Z_3D_TRANSPOSE",
             Format::FORMAT_FRACTAL_Z_3D_TRANSPOSE)
      .value("FORMAT_FRACTAL_ZN_LSTM", Format::FORMAT_FRACTAL_ZN_LSTM)
      .value("FORMAT_FRACTAL_Z_G", Format::FORMAT_FRACTAL_Z_G)
      .value("FORMAT_RESERVED", Format::FORMAT_RESERVED)
      .value("FORMAT_ALL", Format::FORMAT_ALL)
      .value("FORMAT_NULL", Format::FORMAT_NULL)
      .export_values();

  py::enum_<UnknowShapeOpType>(*m, "GEUnknowShapeOpType")
      .value("DEPEND_IN_SHAPE", UnknowShapeOpType::DEPEND_IN_SHAPE)
      .value("DEPEND_CONST_VALUE", UnknowShapeOpType::DEPEND_CONST_VALUE)
      .value("DEPEND_SHAPE_RANGE", UnknowShapeOpType::DEPEND_SHAPE_RANGE)
      .value("DEPEND_COMPUTE", UnknowShapeOpType::DEPEND_COMPUTE)
      .export_values();

  py::enum_<DeviceType>(*m, "GEDeviceType")
      .value("NPU", DeviceType::NPU)
      .value("CPU", DeviceType::CPU)
      .export_values();

  py::enum_<AttrType>(*m, "GEAttrType")
      .value("AT_INT64", AttrType::AT_INT64)
      .value("AT_INT32", AttrType::AT_INT32)
      .value("AT_UINT32", AttrType::AT_UINT32)
      .value("AT_LIST_INT64", AttrType::AT_LIST_INT64)
      .value("AT_LIST_INT32", AttrType::AT_LIST_INT32)
      .value("AT_LIST_UINT32", AttrType::AT_LIST_UINT32)
      .value("AT_FLOAT", AttrType::AT_FLOAT)
      .value("AT_LIST_FLOAT", AttrType::AT_LIST_FLOAT)
      .value("AT_ATTR_VALUE", AttrType::AT_ATTR_VALUE)
      .value("AT_STRING", AttrType::AT_STRING)
      .value("AT_LIST_STRING", AttrType::AT_LIST_STRING)
      .value("AT_BOOL", AttrType::AT_BOOL)
      .value("AT_LIST_BOOL", AttrType::AT_LIST_BOOL)
      .value("AT_TENSOR", AttrType::AT_TENSOR)
      .value("AT_LIST_TENSOR", AttrType::AT_LIST_TENSOR)
      .value("AT_LIST_UINT8", AttrType::AT_LIST_UINT8)
      .value("AT_LIST_LIST_INT64", AttrType::AT_LIST_LIST_INT64)
      .value("AT_LIST_DT", AttrType::AT_LIST_DT)
      .value("AT_DT", AttrType::AT_DT)
      .value("AT_LIST_NAMEATTR", AttrType::AT_LIST_NAMEATTR)
      .value("AT_NAMEATTR", AttrType::AT_NAMEATTR)
      .export_values();

  // 类封装
  py::class_<Session>(*m, "GESession")
      .def(py::init([](const std::map<std::string, std::string> &options) {
        return std::unique_ptr<ge::Session>(
            new ge::Session(convert_map(options)));
      }))
      .def("add_graph", (ge::Status (Session::*)(uint32_t, const Graph &)) &
                            Session::AddGraph)
      .def("add_graph",
           [](Session &ss, uint32_t index, const Graph &graph,
              const std::map<std::string, std::string> &options) {
             return ss.AddGraph(index, graph, convert_map(options));
           })
      .def("remove_graph", &Session::RemoveGraph)
      .def("run_graph",
           [](Session &ss, uint32_t graphId,
              const std::vector<Tensor> &inputs) -> py::tuple {
             std::vector<Tensor> outputs;
             ge::Status res = ss.RunGraph(graphId, inputs, outputs);
             return py::make_tuple(outputs, res);
           },
           py::call_guard<py::gil_scoped_release>())
      .def("build_graph", &Session::BuildGraph)
      .def("run_graph_async", &Session::RunGraphAsync)
#ifdef PADDLE_WITH_ASCEND_STRING
      .def("register_call_back_func",
           static_cast<ge::Status (ge::Session::*)(  // NOLINT
               const char *, const ge::session::pCallBackFunc &)>(
               &ge::Session::RegisterCallBackFunc))
#else
      .def("register_call_back_func",
           (Status (Session::*)(  // NOLINT
               const std::string &,
               std::function<uint32_t(
                   uint32_t graph_id,
                   const std::map<std::string, ge::Tensor> &params_list)>)) &
               Session::RegisterCallBackFunc)
#endif
      .def("is_graph_need_rebuild", &Session::IsGraphNeedRebuild);

  py::class_<Graph>(*m, "GEGraph")
      .def(py::init<>())
      .def(py::init<const char *>())
      .def("set_inputs", &Graph::SetInputs)
      .def("set_outputs", (Graph & (Graph::*)(const std::vector<Operator> &)) &
                              Graph::SetOutputs)
      .def("set_outputs",
           (Graph & (Graph::*)(const std::vector<
                               std::pair<Operator, std::vector<size_t>>> &)) &
               Graph::SetOutputs)
      .def("set_outputs",
           (Graph &
            (Graph::*)(const std::vector<std::pair<ge::Operator, AscendString>>
                           &)) &
               Graph::SetOutputs)
      .def("set_targets", &Graph::SetTargets)
      .def("is_valid", &Graph::IsValid)
      .def("add_op", &Graph::AddOp)
      .def("find_op_by_name",
           [](Graph &graph, const char *name) -> py::tuple {
             ge::Operator op;
             graphStatus status = graph.FindOpByName(name, op);
             return py::make_tuple(op, status);
           })
      .def("find_op_by_type",
           [](Graph &graph, const char *type) -> py::tuple {
             std::vector<ge::Operator> ops;
             graphStatus status = graph.FindOpByType(type, ops);
             return py::make_tuple(ops, status);
           })
      .def("get_all_op_name",
           [](Graph &graph) -> py::tuple {
             std::vector<AscendString> op_name;
             graphStatus status = graph.GetAllOpName(op_name);
             return py::make_tuple(op_name, status);
           })
#ifdef PADDLE_WITH_ASCEND_STRING
      .def("save_to_file",
           static_cast<ge::graphStatus (ge::Graph::*)(const char *) const>(
               &ge::Graph::SaveToFile))
      .def("load_from_file",
           static_cast<ge::graphStatus (ge::Graph::*)(const char *)>(
               &Graph::LoadFromFile))
      .def("get_name",
           static_cast<ge::graphStatus (ge::Graph::*)(AscendString &) const>(
               &Graph::GetName))
#else
      .def("save_to_file", &Graph::SaveToFile)
      .def("load_from_file", &Graph::LoadFromFile)
      .def("get_name", &Graph::GetName)
#endif
      .def("set_need_iteration", &Graph::SetNeedIteration);

  py::class_<Operator>(*m, "GEOperator")
      .def(py::init<>())
      .def(py::init<const char *>())
      .def(py::init<const char *, const char *>())
      .def("is_empty", &Operator::IsEmpty)
#ifdef PADDLE_WITH_ASCEND_STRING
      .def("get_name",
           static_cast<ge::graphStatus (ge::Operator::*)(AscendString &) const>(
               &Operator::GetName))
      .def("get_op_type",
           static_cast<ge::graphStatus (ge::Operator::*)(AscendString &) const>(
               &Operator::GetOpType))
      .def("set_input",
           (Operator & (Operator::*)(const char *, const Operator &)) &
               Operator::SetInput)
      .def("set_input",
           (Operator &
            (Operator::*)(const char *, const Operator &, const char *)) &
               Operator::SetInput)
      .def("set_input", (Operator & (Operator::*)(const char *,
                                                  const Operator &, uint32_t)) &
                            Operator::SetInput)
#else
      .def("get_name", &Operator::GetName)
      .def("get_op_type", &Operator::GetOpType)
      .def("set_input",
           (Operator & (Operator::*)(const std::string &, const Operator &)) &
               Operator::SetInput)
      .def("set_input",
           (Operator & (Operator::*)(const std::string &, const Operator &,
                                     const std::string &)) &
               Operator::SetInput)
      .def("set_input", (Operator & (Operator::*)(const std::string &,
                                                  const Operator &, uint32_t)) &
                            Operator::SetInput)
#endif
      .def("add_control_input", &Operator::AddControlInput)
      .def("get_input_const_data",
           [](Operator &op, const char *dst_name) -> py::tuple {
             Tensor data;
             graphStatus res = op.GetInputConstData(dst_name, data);
             return py::make_tuple(data, res);
           })
#ifdef PADDLE_WITH_ASCEND_STRING
      .def("get_input_desc",
           (TensorDesc (Operator::*)(uint32_t) const) & Operator::GetInputDesc)
      .def("get_input_desc",
           [](Operator &op, const std::string &name) {
             return op.GetInputDescByName(name.c_str());
           })
      .def("get_dynamic_output_num",
           static_cast<int (ge::Operator::*)(const char *) const>(
               &Operator::GetDynamicOutputNum))
      .def("get_dynamic_input_num",
           static_cast<int (ge::Operator::*)(const char *) const>(
               &Operator::GetDynamicInputNum))
#else
      .def("get_input_desc",
           (TensorDesc (Operator::*)(const std::string &) const) &
               Operator::GetInputDesc)
      .def("get_input_desc",
           (TensorDesc (Operator::*)(uint32_t) const) & Operator::GetInputDesc)
      .def("get_dynamic_output_num", &Operator::GetDynamicOutputNum)
      .def("get_dynamic_input_num", &Operator::GetDynamicInputNum)
#endif
      .def("try_get_input_desc",
           [](Operator &op, const char *name) -> py::tuple {
             TensorDesc tensor_desc;
             graphStatus status = op.TryGetInputDesc(name, tensor_desc);
             return py::make_tuple(tensor_desc, status);
           })
#ifdef PADDLE_WITH_ASCEND_STRING
      .def("update_input_desc",
           static_cast<ge::graphStatus (ge::Operator::*)(  // NOLINT
               const char *, const TensorDesc &)>(&Operator::UpdateInputDesc))
      .def("get_output_desc",
           [](Operator &op, const std::string &name) {
             return op.GetOutputDescByName(name.c_str());
           })
      .def("get_output_desc",
           (TensorDesc (Operator::*)(uint32_t) const) & Operator::GetOutputDesc)
      .def("update_output_desc",
           static_cast<ge::graphStatus (ge::Operator::*)(  // NOLINT
               const char *, const TensorDesc &)>(&Operator::UpdateOutputDesc))
      .def("get_dynamic_input_desc",
           static_cast<ge::TensorDesc (ge::Operator::*)(const char *, uint32_t)
                           const>(&Operator::GetDynamicInputDesc))
      .def("update_dynamic_input_desc",
           static_cast<ge::graphStatus (ge::Operator::*)(const char *, uint32_t,
                                                         const TensorDesc &)>(
               &Operator::UpdateDynamicInputDesc))
      .def("get_dynamic_output_desc",
           static_cast<ge::TensorDesc (ge::Operator::*)(const char *, uint32_t)
                           const>(&Operator::GetDynamicOutputDesc))
      .def("update_dynamic_output_desc",
           static_cast<ge::graphStatus (ge::Operator::*)(const char *, uint32_t,
                                                         const TensorDesc &)>(
               &Operator::UpdateDynamicOutputDesc))
#else
      .def("update_input_desc", &Operator::UpdateInputDesc)
      .def("get_output_desc",
           (TensorDesc (Operator::*)(const std::string &) const) &
               Operator::GetOutputDesc)
      .def("get_output_desc",
           (TensorDesc (Operator::*)(uint32_t) const) & Operator::GetOutputDesc)
      .def("update_output_desc", &Operator::UpdateOutputDesc)
      .def("get_dynamic_input_desc", &Operator::GetDynamicInputDesc)
      .def("update_dynamic_input_desc", &Operator::UpdateDynamicInputDesc)
      .def("get_dynamic_output_desc", &Operator::GetDynamicOutputDesc)
      .def("update_dynamic_output_desc", &Operator::UpdateDynamicOutputDesc)
#endif
      .def("infer_shape_and_type", &Operator::InferShapeAndType)
      .def("set_inference_context", &Operator::SetInferenceContext)
      .def("get_inference_context", &Operator::GetInferenceContext)
      .def("verify_all_attr", &Operator::VerifyAllAttr)
      .def("get_inputs_size", &Operator::GetInputsSize)
      .def("get_outputs_size", &Operator::GetOutputsSize)
#ifdef PADDLE_WITH_ASCEND_STRING
      .def("get_all_attr_names_and_types",
           static_cast<ge::graphStatus (ge::Operator::*)(  // NOLINT
               std::map<AscendString, AscendString> &) const>(
               &Operator::GetAllAttrNamesAndTypes))
#else
      .def("get_all_attr_names_and_types", &Operator::GetAllAttrNamesAndTypes)
#endif
      .def("set_attr_int64",
           [](Operator &op, const char *name, int64_t value) -> Operator & {
             int64_t tar = (int64_t)value;
             return op.SetAttr(name, tar);
           })
      .def("set_attr_int32",
           [](Operator &op, const char *name, int32_t value) -> Operator & {
             int32_t tar = (int32_t)value;
             return op.SetAttr(name, tar);
           })
      .def("set_attr_uint32",
           [](Operator &op, const char *name, uint32_t value) -> Operator & {
             uint32_t tar = (uint32_t)value;
             return op.SetAttr(name, tar);
           })
      .def("set_attr_vec_int64",
           [](Operator &op, const char *name,
              const std::vector<int64_t> &value) -> Operator & {
             int len = value.size();
             std::vector<int64_t> tar;
             int64_t tmp;
             for (int i = 0; i < len; i++) {
               tmp = (int64_t)value[i];
               tar.push_back(tmp);
             }
             return op.SetAttr(name, tar);
           })
      .def("set_attr_vec_int32",
           [](Operator &op, const char *name,
              const std::vector<int32_t> &value) -> Operator & {
             int len = value.size();
             std::vector<int32_t> tar;
             int32_t tmp;
             for (int i = 0; i < len; i++) {
               tmp = (int32_t)value[i];
               tar.push_back(tmp);
             }
             return op.SetAttr(name, tar);
           })
      .def("set_attr_vec_uint32",
           [](Operator &op, const char *name,
              const std::vector<uint32_t> &value) -> Operator & {
             int len = value.size();
             std::vector<uint32_t> tar;
             uint32_t tmp;
             for (int i = 0; i < len; i++) {
               tmp = (uint32_t)value[i];
               tar.push_back(tmp);
             }
             return op.SetAttr(name, tar);
           })
      .def("set_attr_list_int64",
           [](Operator &op, const char *name,
              std::initializer_list<int64_t> &attrValue) -> Operator & {
             return op.SetAttr(name, std::move(attrValue));
           })
      .def("set_attr_attrvalue",
           [](Operator &op, const char *name, AttrValue &attrValue)
               -> Operator & { return op.SetAttr(name, std::move(attrValue)); })
      .def("set_attr_float",
           [](Operator &op, const char *name, float value) -> Operator & {
             float tar = static_cast<float>(value);
             return op.SetAttr(name, tar);
           })
      .def("set_attr_vec_float",
           [](Operator &op, const char *name,
              const std::vector<float> &value) -> Operator & {
             int len = value.size();
             std::vector<float> tar;
             float tmp;
             for (int i = 0; i < len; i++) {
               tmp = static_cast<float>(value[i]);
               tar.push_back(tmp);
             }
             return op.SetAttr(name, tar);
           })
#ifdef PADDLE_WITH_ASCEND_STRING
      .def("set_attr_string",
           (Operator & (Operator::*)(const char *, const char *)) &
               Operator::SetAttr)
      .def("set_attr_vec_string",
           (Operator &
            (Operator::*)(const char *, const std::vector<AscendString> &)) &
               Operator::SetAttr)
#else
      .def("set_attr_string", (Operator & (Operator::*)(const std::string &,
                                                        const std::string &)) &
                                  Operator::SetAttr)
      .def("set_attr_vec_string",
           (Operator & (Operator::*)(const std::string &,
                                     const std::vector<std::string> &)) &
               Operator::SetAttr)
#endif
      .def("set_attr_bool",
           [](Operator &op, const char *name, bool value) -> Operator & {
             if (value)
               return op.SetAttr(name, true);
             else
               return op.SetAttr(name, false);
           })
      .def("set_attr_vec_bool",
           [](Operator &op, const char *name,
              const std::vector<bool> &value) -> Operator & {
             int len = value.size();
             std::vector<bool> tar;
             for (int i = 0; i < len; i++) {
               if (value[i])
                 tar.push_back(true);
               else
                 tar.push_back(false);
             }
             return op.SetAttr(name, tar);
           })
#ifdef PADDLE_WITH_ASCEND_STRING
      .def("set_attr_tensor",
           (Operator & (Operator::*)(const char *, const Tensor &)) &
               Operator::SetAttr)
      .def("set_attr_vec_tensor",
           (Operator &
            (Operator::*)(const char *, const std::vector<Tensor> &)) &
               Operator::SetAttr)
#else
      .def("set_attr_tensor",
           (Operator & (Operator::*)(const std::string &, const Tensor &)) &
               Operator::SetAttr)
      .def("set_attr_vec_tensor",
           (Operator &
            (Operator::*)(const std::string &, const std::vector<Tensor> &)) &
               Operator::SetAttr)
#endif
      .def("set_attr_vec_uint8",
           [](Operator &op, const char *name,
              const std::vector<uint8_t> &value) -> Operator & {
             int len = value.size();
             std::vector<uint8_t> tar;
             uint8_t tmp;
             for (int i = 0; i < len; i++) {
               tmp = (uint8_t)value[i];
               tar.push_back(tmp);
             }
             return op.SetAttr(name, tar);
           })
#ifdef PADDLE_WITH_ASCEND_STRING
      .def("set_attr_vec_vec_int64",
           (Operator &
            (Operator::*)(const char *,
                          const std::vector<std::vector<int64_t>> &)) &
               Operator::SetAttr)
#else
      .def("set_attr_vec_vec_int64",
           (Operator &
            (Operator::*)(const std::string &,
                          const std::vector<std::vector<int64_t>> &)) &
               Operator::SetAttr)
#endif
      .def("set_attr_vec_dtype",
           [](Operator &op, const char *name,
              const std::vector<DataType> &value) -> Operator & {
             int len = value.size();
             std::vector<ge::DataType> tar;
             ge::DataType tmp;
             for (int i = 0; i < len; i++) {
               tmp = (ge::DataType)value[i];
               tar.push_back(tmp);
             }
             return op.SetAttr(name, tar);
           })
      .def("set_attr_dtype",
           [](Operator &op, const char *name,
              const DataType &value) -> Operator & {
             ge::DataType tar = (ge::DataType)value;
             return op.SetAttr(name, tar);
           })
      .def("get_attr",
           [](Operator &op, const char *name, AttrType type) -> py::tuple {
             graphStatus res = -1;
             switch (type) {
               case AT_INT64: {
                 int64_t i_64_av;
                 res = op.GetAttr(name, i_64_av);
                 return py::make_tuple(i_64_av, res);
               } break;
               case AT_INT32: {
                 int32_t i_32_av;
                 res = op.GetAttr(name, i_32_av);
                 return py::make_tuple(i_32_av, res);
               } break;
               case AT_UINT32: {
                 uint32_t ui_32_av;
                 res = op.GetAttr(name, ui_32_av);
                 return py::make_tuple(ui_32_av, res);
               } break;
               case AT_LIST_INT64: {
                 std::vector<int64_t> v_i_64_av;
                 res = op.GetAttr(name, v_i_64_av);
                 return py::make_tuple(v_i_64_av, res);
               } break;
               case AT_LIST_INT32: {
                 std::vector<int32_t> v_i_32_av;
                 res = op.GetAttr(name, v_i_32_av);
                 return py::make_tuple(v_i_32_av, res);
               } break;
               case AT_LIST_UINT32: {
                 std::vector<uint32_t> v_ui_32_av;
                 res = op.GetAttr(name, v_ui_32_av);
                 return py::make_tuple(v_ui_32_av, res);
               } break;
               case AT_FLOAT: {
                 float f_av;
                 res = op.GetAttr(name, f_av);
                 return py::make_tuple(f_av, res);
               } break;
               case AT_LIST_FLOAT: {
                 std::vector<float> v_f_av;
                 res = op.GetAttr(name, v_f_av);
                 return py::make_tuple(v_f_av, res);
               } break;
               case AT_ATTR_VALUE: {
                 AttrValue o_av;
                 res = op.GetAttr(name, o_av);
                 return py::make_tuple(o_av, res);
               } break;
               case AT_STRING: {
                 AscendString s_av;
                 res = op.GetAttr(name, s_av);
                 return py::make_tuple(s_av, res);
               } break;
               case AT_LIST_STRING: {
                 std::vector<AscendString> v_s_av;
                 res = op.GetAttr(name, v_s_av);
                 return py::make_tuple(v_s_av, res);
               } break;
               case AT_BOOL: {
                 bool b_av;
                 res = op.GetAttr(name, b_av);
                 return py::make_tuple(b_av, res);
               } break;
               case AT_LIST_BOOL: {
                 std::vector<bool> v_b_av;
                 res = op.GetAttr(name, v_b_av);
                 return py::make_tuple(v_b_av, res);
               } break;
               case AT_TENSOR: {
                 Tensor t_av;
                 res = op.GetAttr(name, t_av);
                 return py::make_tuple(t_av, res);
               } break;
               case AT_LIST_TENSOR: {
                 std::vector<Tensor> v_t_av;
                 res = op.GetAttr(name, v_t_av);
                 return py::make_tuple(v_t_av, res);
               } break;
               case AT_LIST_UINT8: {
                 std::vector<uint8_t> v_ui_8_av;
                 res = op.GetAttr(name, v_ui_8_av);
                 return py::make_tuple(v_ui_8_av, res);
               } break;
               case AT_LIST_LIST_INT64: {
                 std::vector<std::vector<int64_t>> v_v_i_64_av;
                 res = op.GetAttr(name, v_v_i_64_av);
                 return py::make_tuple(v_v_i_64_av, res);
               } break;
               case AT_DT: {
                 ge::DataType dt_av;
                 res = op.GetAttr(name, dt_av);
                 return py::make_tuple(dt_av, res);
               } break;
               case AT_LIST_DT: {
                 std::vector<ge::DataType> v_dt_av;
                 res = op.GetAttr(name, v_dt_av);
                 return py::make_tuple(v_dt_av, res);
               } break;
               default:
                 return py::make_tuple(0, res);
                 break;
             }
           })
      .def("break_connect", &Operator::BreakConnect)
      .def("get_subgraph_names_count", &Operator::GetSubgraphNamesCount)
#ifdef PADDLE_WITH_ASCEND_STRING
      .def("get_subgraph_names",
           static_cast<ge::graphStatus (ge::Operator::*)(  // NOLINT
               std::vector<AscendString> &) const>(&Operator::GetSubgraphNames))
      .def("get_subgraph_builder",
           static_cast<ge::SubgraphBuilder (ge::Operator::*)(const char *)
                           const>(&Operator::GetSubgraphBuilder))
      .def("get_subgraph",
           static_cast<ge::Graph (ge::Operator::*)(const char *) const>(
               &Operator::GetSubgraph))
      .def("get_dynamic_subgraph_builder",
           static_cast<ge::SubgraphBuilder (ge::Operator::*)(const char *,
                                                             uint32_t) const>(
               &Operator::GetDynamicSubgraphBuilder))
      .def("get_dynamic_subgraph",
           static_cast<ge::Graph (ge::Operator::*)(const char *, uint32_t)
                           const>(&Operator::GetDynamicSubgraph));
#else
      .def("get_subgraph_names_count", &Operator::GetSubgraphNamesCount)
      .def("get_subgraph_names", &Operator::GetSubgraphNames)
      .def("get_subgraph_builder", &Operator::GetSubgraphBuilder)
      .def("get_subgraph", &Operator::GetSubgraph)
      .def("get_dynamic_subgraph_builder", &Operator::GetDynamicSubgraphBuilder)
      .def("get_dynamic_subgraph", &Operator::GetDynamicSubgraph);
#endif

  py::class_<Tensor>(*m, "GETensor")
      .def(py::init<>())
      .def(py::init<const TensorDesc &>())
      .def(py::init<const TensorDesc &, const std::vector<uint8_t> &>())
      .def(py::init<const TensorDesc &, const uint8_t *, size_t>())
      .def("set_tensor_desc", &Tensor::SetTensorDesc)
      .def("get_tensor_desc", &Tensor::GetTensorDesc)
      // .def("set_data", (graphStatus(Tensor::*)(std::vector<uint8_t> &&)) &
      // Tensor::SetData)
      .def("set_data", (graphStatus (Tensor::*)(const std::vector<uint8_t> &)) &
                           Tensor::SetData)
      .def("set_data",
           (graphStatus (Tensor::*)(const uint8_t *, size_t)) & Tensor::SetData)
#ifdef PADDLE_WITH_ASCEND_STRING
      .def("set_data",
           (graphStatus (Tensor::*)(const char *)) & Tensor::SetData)
#else
      .def("set_data",
           (graphStatus (Tensor::*)(const std::string &)) & Tensor::SetData)
#endif
      .def("set_data",
           (graphStatus (Tensor::*)(const std::vector<AscendString> &)) &
               Tensor::SetData)

      .def("get_data",
           [](Tensor &ts) -> py::list {
             py::list v_data;
             uint8_t *data = ts.GetData();
             size_t size = ts.GetSize();
             for (size_t i = 0; i < size; ++i) {
               v_data.append(data[i]);
             }
             return v_data;
           })
      .def("get_size", &Tensor::GetSize)
      .def("is_valid", &Tensor::IsValid)
      .def("clone", &Tensor::Clone);

  py::class_<TensorDesc>(*m, "GETensorDesc")
      .def(py::init<>())
      .def(py::init<Shape, Format, DataType>(), py::arg("shape"),
           py::arg("format") = FORMAT_ND, py::arg("dt") = DT_FLOAT)
      .def(py::init<const TensorDesc &>())
      .def("update", (void (TensorDesc::*)(const Shape &, Format, DataType)) &
                         TensorDesc::Update,
           py::arg("shape"), py::arg("format") = FORMAT_ND,
           py::arg("dt") = DT_FLOAT)
      .def("set_shape", &TensorDesc::SetShape)
      .def("get_shape", &TensorDesc::GetShape)
      .def("set_unknown_dim_num_shape", &TensorDesc::SetUnknownDimNumShape)
      .def("set_shape_range", &TensorDesc::SetShapeRange)
      .def("get_shape_range",
           [](TensorDesc &tensorDesc) -> py::tuple {
             std::vector<std::pair<int64_t, int64_t>> range;
             graphStatus status = tensorDesc.GetShapeRange(range);
             return py::make_tuple(range, status);
           })
      .def("set_format", &TensorDesc::SetFormat)
      .def("get_format", &TensorDesc::GetFormat)
      .def("get_origin_shape", &TensorDesc::GetOriginShape)
      .def("set_origin_shape", &TensorDesc::SetOriginShape)
      .def("set_origin_format", &TensorDesc::SetOriginFormat)
      .def("get_origin_format", &TensorDesc::GetOriginFormat)
      .def("set_data_type", &TensorDesc::SetDataType)
      .def("get_data_type", &TensorDesc::GetDataType)
#ifdef PADDLE_WITH_ASCEND_STRING
      .def("set_name", static_cast<void (ge::TensorDesc::*)(const char *)>(
                           &TensorDesc::SetName))
      .def("get_name",
           static_cast<ge::graphStatus (ge::TensorDesc::*)(AscendString &)>(
               &TensorDesc::GetName))
#else
      .def("set_name", &TensorDesc::SetName)
      .def("get_name", &TensorDesc::GetName)
#endif
      .def("set_size", &TensorDesc::SetSize)
      .def("get_size", &TensorDesc::GetSize)
      .def("set_real_dim_cnt", &TensorDesc::SetRealDimCnt)
      .def("get_real_dim_cnt", &TensorDesc::GetRealDimCnt);

  py::class_<Shape>(*m, "GEShape")
      .def(py::init<>())
      .def(py::init<const std::vector<int64_t> &>())
      .def("get_dim_num", &Shape::GetDimNum)
      .def("set_dim", &Shape::SetDim)
      .def("get_dim", &Shape::GetDim)
      .def("get_dims", &Shape::GetDims)
      .def("get_shape_size", &Shape::GetShapeSize);

  py::class_<AttrValue>(*m, "GEAttrValue").def(py::init<>());

  py::class_<OperatorFactory>(*m, "GEOperatorFactory")
#ifdef PADDLE_WITH_ASCEND_STRING
      .def_static("create_operator",
                  static_cast<ge::Operator (*)(const char *, const char *)>(
                      &ge::OperatorFactory::CreateOperator))
#else
      .def("create_operator", &OperatorFactory::CreateOperator)
#endif
      .def("get_ops_type_list",
           []() -> py::tuple {
             std::vector<AscendString> all_ops;
             graphStatus status = OperatorFactory::GetOpsTypeList(all_ops);
             return py::make_tuple(all_ops, status);
           })
#ifdef PADDLE_WITH_ASCEND_STRING
      .def_static("is_exist_op", static_cast<bool (*)(const char *)>(
                                     &OperatorFactory::IsExistOp));
#else
      .def("is_exist_op", &OperatorFactory::IsExistOp);
#endif
}

}  // namespace pybind
}  // namespace paddle
#endif
